{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JyWknuJRdhK8"
      },
      "source": [
        "# MPLP: Sinusoid 5 step, 10 shot\n",
        "\n",
        "\n",
        "Copyright 2020 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "c6GXN11wdbvU"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import io\n",
        "import numpy as np\n",
        "import glob\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "tf.enable_v2_behavior()\n",
        "tf.get_logger().setLevel('ERROR')\n",
        "\n",
        "import matplotlib.pyplot as plt # visualization\n",
        "\n",
        "from collections import defaultdict\n",
        "import random\n",
        "\n",
        "import itertools\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow.compat.v2 as tf\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import IPython.display as display\n",
        "from IPython.display import clear_output\n",
        "\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import os\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QBo8SVVqmLGU"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -e git+https://github.com/google-research/self-organising-systems.git#egg=mplp\u0026subdirectory=mplp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cAm7UBfbjJz6"
      },
      "outputs": [],
      "source": [
        "# symlink for saved models.\n",
        "!ln -s src/mplp/mplp/savedmodels savedmodels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VdIxbXAYgI5X"
      },
      "outputs": [],
      "source": [
        "\n",
        "from mplp.tf_layers import MPDense\n",
        "from mplp.tf_layers import MPActivation\n",
        "from mplp.tf_layers import MPSoftmax\n",
        "from mplp.tf_layers import MPL2Loss\n",
        "from mplp.tf_layers import MPNetwork\n",
        "from mplp.sinusoidals import SinusoidalsDS\n",
        "from mplp.util import SamplePool\n",
        "from mplp.training import TrainingRegime\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qKEOL1bJp3XQ"
      },
      "source": [
        "The task is to fit sinusoidals from randomly initialized networks.\n",
        "\n",
        "Therefore, there are:\n",
        "* Outer batch size = 4 number of tasks at every step. Each has a different network, different amplitude and different phase.\n",
        "* Inner batch size = 10 number of examples for each forward/backward steps.\n",
        "* num steps = 5, number of inner steps the network has to get better.\n",
        "* train/eval split: the network only sees train instances during forward/backward. The meta-learning regime *may* choose to use eval splits as well, MAML-style."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "Yx3ubhzRgPHh"
      },
      "outputs": [],
      "source": [
        "# @title create dataset and plot it\n",
        "OUTER_BATCH_SIZE = 4\n",
        "INNER_BATCH_SIZE = 10\n",
        "NUM_STEPS = 5\n",
        "\n",
        "ds_factory = SinusoidalsDS()\n",
        "\n",
        "ds = ds_factory.create_ds(OUTER_BATCH_SIZE, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "ds_iter = iter(ds)\n",
        "\n",
        "# Utility range\n",
        "xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)).astype(np.float32)\n",
        "\n",
        "xtb, ytb, xeb, yeb = next(ds_iter)\n",
        "plt.figure(figsize=(14, 10))\n",
        "colors = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])\n",
        "for xts, yts, xes, yes in zip(xtb, ytb, xeb, yeb):\n",
        "  c_t = next(colors)\n",
        "  c_e = next(colors)\n",
        "  markers = itertools.cycle((',', '+', '.', 'o', '*')) \n",
        "  for xtsib, ytsib, xesib, yesib in zip(xts, yts, xes, yes):\n",
        "    marker = next(markers)\n",
        "    plt.scatter(xtsib, ytsib, c=c_t, marker=marker)\n",
        "    plt.scatter(xesib, yesib, c=c_e, marker=marker)\n",
        "\n",
        "plt.show()\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Aq5JC810rXYF"
      },
      "source": [
        "Create a MP network:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lD0T9QJFnEaz"
      },
      "outputs": [],
      "source": [
        "\n",
        "# This is the size of the message passed.\n",
        "message_size = 8\n",
        "stateful = True\n",
        "stateful_hidden_n = 15\n",
        "\n",
        "# This network is keras-style initialized.\n",
        "# If you want to create a single layer, you need to pass it also the in_dim\n",
        "# and message size.\n",
        "network = MPNetwork(\n",
        "    [\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), \n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     ],\n",
        "     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))\n",
        "network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)\n",
        "\n",
        "# see trainable weights:\n",
        "tr_w = network.get_trainable_weights()\n",
        "print(\"trainable weights:\")\n",
        "tot_w = 0\n",
        "for w in tr_w:\n",
        "  w = w.numpy()\n",
        "  w_size = w.size\n",
        "  tot_w += w_size\n",
        "\n",
        "  print(w.shape, w_size)\n",
        "print(\"tot n:\", tot_w)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "776DRprNBM4p"
      },
      "outputs": [],
      "source": [
        "\n",
        "def create_new_p_fw():\n",
        "  return network.init(INNER_BATCH_SIZE)\n",
        "\n",
        "POOL_SIZE = 128\n",
        "\n",
        "pool = SamplePool(ps=tf.stack(\n",
        "    [network.serialize_pfw(create_new_p_fw()) for _ in range(POOL_SIZE)]).numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Xv5JvNrLaDya"
      },
      "outputs": [],
      "source": [
        "num_steps = tf.constant(NUM_STEPS)\n",
        "\n",
        "learning_schedule = 1e-4\n",
        "\n",
        "# Prepare a training regime.\n",
        "# The heldout_weight tells you how to split the loss between train and eval sets\n",
        "# that are passed to the network.\n",
        "# Empirically, a heldout_weight=0.0 (or None), results in a much lower overall\n",
        "# performance, both for train and test losses.\n",
        "training_regime = TrainingRegime(\n",
        "    network, heldout_weight=1.0, hint_loss_ratio=0.7, remember_loss_ratio=0.0)\n",
        "\n",
        "last_step = 0\n",
        "\n",
        "\n",
        "# minibatch init, allowing to initialize by looking at more\n",
        "# than just one step.\n",
        "# Likewise, this can be run multiple times to improve the initialization.\n",
        "for j in range(1):\n",
        "  print(\"on\", j)\n",
        "  stats = []\n",
        "  pfw = network.init(INNER_BATCH_SIZE)\n",
        "\n",
        "  x_b, y_b, _, _ = next(ds_iter)\n",
        "  x_b, y_b = x_b[0], y_b[0]\n",
        "  for i in range(NUM_STEPS):\n",
        "    pfw, stats_i = network.minibatch_init(x_b[i],  y_b[i], x_b[i].shape[0], pfw=pfw)\n",
        "    stats.append(stats_i)\n",
        "  # update\n",
        "  network.update_statistics(stats, update_perc=1.)\n",
        "\n",
        "  print(\"final mean:\")\n",
        "  for p in tf.nest.flatten(pfw):\n",
        "    print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))\n",
        "\n",
        "\n",
        "# The outer loop here uses Adam. SGD/Momentum are more stable but way slower.\n",
        "trainer = tf.keras.optimizers.Adam(learning_schedule)\n",
        "\n",
        "loss_log = []\n",
        "def smoothen(l, lookback=20):\n",
        "  kernel = [1./lookback] * lookback\n",
        "  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, \"valid\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "sWlFhwL5od7q"
      },
      "outputs": [],
      "source": [
        "training_steps = 200000\n",
        "\n",
        "@tf.function\n",
        "def step(pfws, xts, yts, xes, yes, num_steps):\n",
        "  print(\"compiling\")\n",
        "  with tf.GradientTape() as g:\n",
        "    l, _, _ = training_regime.batch_mp_loss(pfws, xts, yts, xes, yes, num_steps)\n",
        "  all_weights = network.get_trainable_weights()\n",
        "  grads = g.gradient(l, all_weights)\n",
        "  # Try grad clipping to avoid explosions.\n",
        "  grads = [g/(tf.norm(g)+1e-8) for g in grads]\n",
        "  trainer.apply_gradients(zip(grads, all_weights))\n",
        "  return l\n",
        "\n",
        "\n",
        "import time\n",
        "start_time = time.time()\n",
        "\n",
        "for i in range(last_step + 1, last_step +1 + training_steps):\n",
        "  last_step = i\n",
        "\n",
        "  tmp_t = time.time()\n",
        "  xts, yts, xes, yes = next(ds_iter)\n",
        "\n",
        "  batch = pool.sample(OUTER_BATCH_SIZE)\n",
        "  fwps = batch.ps\n",
        "\n",
        "  l = step(fwps, xts, yts, xes, yes, num_steps)\n",
        "  loss_log.append(l)\n",
        "\n",
        "  if i % 50 == 0:\n",
        "    print(i)\n",
        "    print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "  if i % 500 == 0:\n",
        "    plt.plot(smoothen(loss_log, 100), label='mp')\n",
        "    plt.yscale('log')\n",
        "    plt.legend()\n",
        "    plt.show()\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QUmKqMWxqLtI"
      },
      "outputs": [],
      "source": [
        "\n",
        "print(loss_log[-1])\n",
        "plt.plot(smoothen(loss_log, 100), label='mp')\n",
        "plt.yscale('log')\n",
        "plt.ylim(0.0, 1e-1)\n",
        "plt.gca().yaxis.grid(True)\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tf4ENAZs-zP5"
      },
      "source": [
        "#Proper evaluation: run 100 different few-shot instances with totally new network params.\n",
        "\n",
        "The train loss is computed only on points that the network has already observed.\n",
        "\n",
        "The eval loss is computed on the entire range [-5, 5]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "SFbFuJ_GDBU5"
      },
      "outputs": [],
      "source": [
        "!mkdir tmp\n",
        "\n",
        "!ls tmp -R"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NI47xog03QCg"
      },
      "outputs": [],
      "source": [
        "\n",
        "!mkdir tmp\n",
        "file_path = \"tmp/weights\"\n",
        "\n",
        "network.save_weights(file_path, last_step)\n",
        "!ls -lh tmp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MPJlM6koyeUN"
      },
      "outputs": [],
      "source": [
        "# @title Optionally, load saved model\n",
        "checkpoint_file_path = \"savedmodels/sinusoid_weights\"\n",
        "network = MPNetwork(\n",
        "    [\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), \n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     ],\n",
        "     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))\n",
        "network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)\n",
        "network.load_weights(checkpoint_file_path)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Ty5_TQFm8vGu"
      },
      "outputs": [],
      "source": [
        "\n",
        "eval_tot_steps = 100\n",
        "\n",
        "tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])\n",
        "ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0-step.\n",
        "\n",
        "@tf.function\n",
        "def get_loss(pfw, x, y):\n",
        "  predictions, _= network.forward(pfw, x)\n",
        "  loss, _ = network.compute_loss(predictions, y)\n",
        "  return loss\n",
        "\n",
        "start_time = time.time()\n",
        "\n",
        "for r in range(eval_tot_steps):\n",
        "  p_fw = network.init(INNER_BATCH_SIZE)\n",
        "\n",
        "  A, ph = ds_factory._create_task()\n",
        "\n",
        "  targets = A * np.sin(xrange_inputs + ph)\n",
        "\n",
        "  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "\n",
        "  # initial loss.\n",
        "  loss = get_loss(p_fw, xrange_inputs, targets)\n",
        "  ev_losses[r, 0] = tf.reduce_mean(loss)\n",
        "\n",
        "  for i in range(NUM_STEPS):\n",
        "    p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])\n",
        "\n",
        "    # loss specific to only what we observe.\n",
        "    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "    loss = get_loss(p_fw, x_observed_so_far, y_observed_so_far)\n",
        "    tr_losses[r, i] = tf.reduce_mean(loss)\n",
        "\n",
        "    # Plotting for the continuous input range\n",
        "    loss = get_loss(p_fw, xrange_inputs, targets)\n",
        "    ev_losses[r, i + 1] = tf.reduce_mean(loss)\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n",
        "tr_losses_m = np.mean(tr_losses, axis=0)\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "\n",
        "tr_losses_sd = np.std(tr_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)\n",
        "plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')\n",
        "plt.ylim(0.0, 0.025)\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"L2 loss\")\n",
        "plt.legend()\n",
        "\n",
        "with open(\"tmp/mplp_losses.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "czm5dzijZE2V"
      },
      "outputs": [],
      "source": [
        "print(tr_losses_m, ev_losses_m)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mV-JovFUs1V1"
      },
      "outputs": [],
      "source": [
        "# @title Show example runs:\n",
        "\n",
        "fig, axs = plt.subplots(5, 2, figsize=(10,15))\n",
        "\n",
        "for fig_n in range(5):\n",
        "  p_fw = network.init(INNER_BATCH_SIZE)\n",
        "\n",
        "  n_plot = 5\n",
        "  plot_every = NUM_STEPS // n_plot if NUM_STEPS \u003e= n_plot else NUM_STEPS\n",
        "\n",
        "  predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "\n",
        "  A, ph = ds_factory._create_task()\n",
        "\n",
        "  targets = A * np.sin(xrange_inputs + ph)\n",
        "  axs[fig_n][0].plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "  predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "  axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "  tr_losses = []\n",
        "  ev_losses = []\n",
        "\n",
        "  for i in range(NUM_STEPS):\n",
        "    p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])\n",
        "\n",
        "    # loss specific to only what we observe.\n",
        "    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "    predictions, _= network.forward(p_fw, x_observed_so_far)\n",
        "    loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "    tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "    # Plotting for the continuous input range\n",
        "    predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "    if (i+1) % plot_every == 0:\n",
        "      axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "    loss, _ = network.compute_loss(predictions, targets)\n",
        "    ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  axs[fig_n][1].plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')\n",
        "  axs[fig_n][1].plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')\n",
        "\n",
        "axs[0][0].legend()\n",
        "axs[0][1].legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "0Lw7YaJaJfvB"
      },
      "outputs": [],
      "source": [
        "# @title Single run for drawing.\n",
        "\n",
        "p_fw = network.init(INNER_BATCH_SIZE)\n",
        "\n",
        "n_plot = 5\n",
        "plot_every = NUM_STEPS // n_plot if NUM_STEPS \u003e= n_plot else NUM_STEPS\n",
        "\n",
        "predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "\n",
        "A, ph = ds_factory._create_task()\n",
        "\n",
        "targets = A * np.sin(xrange_inputs + ph)\n",
        "plt.plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "tr_losses = []\n",
        "ev_losses = []\n",
        "\n",
        "for i in range(NUM_STEPS):\n",
        "  p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])\n",
        "\n",
        "  # loss specific to only what we observe.\n",
        "  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "  predictions, _= network.forward(p_fw, x_observed_so_far)\n",
        "  loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "  tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  # Plotting for the continuous input range\n",
        "  predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "  if (i+1) % plot_every == 0:\n",
        "    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "  loss, _ = network.compute_loss(predictions, targets)\n",
        "  ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "\n",
        "plt.legend()\n",
        "\n",
        "\n",
        "with open(\"tmp/mplp_example_run.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1ZnWudcy_qUV"
      },
      "source": [
        "# Compare it with an Adam run (the best I managed to make)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UeGe4aaOKJ6r"
      },
      "outputs": [],
      "source": [
        "adam_network = MPNetwork(\n",
        "      [MPDense(20),\n",
        "       MPActivation(tf.nn.relu),\n",
        "       MPDense(20), \n",
        "       MPActivation(tf.nn.relu),\n",
        "       MPDense(1),\n",
        "      ],\n",
        "      MPL2Loss())\n",
        "adam_network.setup(in_dim=1, message_size=message_size)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NKMDOGpORYyF"
      },
      "outputs": [],
      "source": [
        "eval_tot_steps = 100\n",
        "\n",
        "tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])\n",
        "ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0 step.\n",
        "\n",
        "for r in range(eval_tot_steps):\n",
        "  # We need to transform these into variables.\n",
        "  p_fw = [tf.Variable(t) for t in adam_network.init()]\n",
        "  adam_trainer = tf.keras.optimizers.Adam(0.01)\n",
        "\n",
        "  def adam_step(xt, yt):\n",
        "    with tf.GradientTape() as g:\n",
        "      g.watch(p_fw)\n",
        "      y, _ = adam_network.forward(p_fw, xt)\n",
        "      l, _ = adam_network.compute_loss(y, yt)\n",
        "      l = tf.reduce_mean(l)\n",
        "\n",
        "    grads = g.gradient(l, p_fw)\n",
        "    adam_trainer.apply_gradients(zip(grads, p_fw))\n",
        "\n",
        "    return l\n",
        "\n",
        "  A, ph = ds_factory._create_task()\n",
        "\n",
        "  targets = A * np.sin(xrange_inputs + ph)\n",
        "\n",
        "  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "\n",
        "  # initial loss.\n",
        "  predictions, _ = adam_network.forward(p_fw, xrange_inputs)\n",
        "  loss, _ = adam_network.compute_loss(predictions, targets)\n",
        "  ev_losses[r, 0] = tf.reduce_mean(loss)\n",
        "\n",
        "  for i in range(NUM_STEPS):\n",
        "    adam_step(xt[i], yt[i])\n",
        "\n",
        "    # loss specific to only what we observe.\n",
        "    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "    predictions, _= adam_network.forward(p_fw, x_observed_so_far)\n",
        "    loss, _ = adam_network.compute_loss(predictions, y_observed_so_far)\n",
        "    tr_losses[r, i] = tf.reduce_mean(loss)\n",
        "\n",
        "    # Plotting for the continuous input range\n",
        "    predictions, _= adam_network.forward(p_fw, xrange_inputs)\n",
        "    loss, _ = adam_network.compute_loss(predictions, targets)\n",
        "    ev_losses[r, i + 1] = tf.reduce_mean(loss)\n",
        "\n",
        "tr_losses_m = np.mean(tr_losses, axis=0)\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "\n",
        "tr_losses_sd = np.std(tr_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)\n",
        "plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')\n",
        "plt.ylim(0.0, 0.04)\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"L2 loss\")\n",
        "plt.legend()\n",
        "\n",
        "with open(\"tmp/adam_losses.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vmJj63OEO35M"
      },
      "outputs": [],
      "source": [
        "# Same task as MPLP fo same drawing\n",
        "\n",
        "p_fw = [tf.Variable(t) for t in adam_network.init()]\n",
        "adam_trainer = tf.keras.optimizers.Adam(0.01)\n",
        "\n",
        "n_plot = 5\n",
        "plot_every = NUM_STEPS // n_plot\n",
        "\n",
        "def adam_step(xt, yt):\n",
        "  with tf.GradientTape() as g:\n",
        "    g.watch(p_fw)\n",
        "    y, _ = adam_network.forward(p_fw, xt)\n",
        "    l, _ = adam_network.compute_loss(y, yt)\n",
        "    l = tf.reduce_mean(l)\n",
        "\n",
        "  grads = g.gradient(l, p_fw)\n",
        "  adam_trainer.apply_gradients(zip(grads, p_fw))\n",
        "\n",
        "  return l\n",
        "\n",
        "predictions, _ = adam_network.forward(p_fw, xrange_inputs)\n",
        "\n",
        "A, ph = ds_factory._create_task()\n",
        "\n",
        "targets = A * np.sin(xrange_inputs + ph)\n",
        "plt.plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "predictions, _= adam_network.forward(p_fw, xrange_inputs)\n",
        "plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "tr_losses = []\n",
        "ev_losses = []\n",
        "\n",
        "for i in range(NUM_STEPS):\n",
        "  adam_step(xt[i], yt[i])\n",
        "\n",
        "  # loss specific to only what we observe.\n",
        "  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "  predictions, _= adam_network.forward(p_fw, x_observed_so_far)\n",
        "  loss, _ = adam_network.compute_loss(predictions, y_observed_so_far)\n",
        "  tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  # Plotting for the continuous input range\n",
        "  predictions, _= adam_network.forward(p_fw, xrange_inputs)\n",
        "  if (i+1) % plot_every == 0:\n",
        "    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "  loss, _ = adam_network.compute_loss(predictions, targets)\n",
        "  ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "plt.legend()\n",
        "\n",
        "with open(\"tmp/adam_example_run.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TdbjW_uplJ4H"
      },
      "outputs": [],
      "source": [
        "# @title Show an example run:\n",
        "p_fw = [tf.Variable(t) for t in adam_network.init()]\n",
        "adam_trainer = tf.keras.optimizers.Adam(0.01)\n",
        "\n",
        "n_plot = 5\n",
        "plot_every = NUM_STEPS // n_plot\n",
        "\n",
        "def adam_step(xt, yt):\n",
        "  with tf.GradientTape() as g:\n",
        "    g.watch(p_fw)\n",
        "    y, _ = adam_network.forward(p_fw, xt)\n",
        "    l, _ = adam_network.compute_loss(y, yt)\n",
        "    l = tf.reduce_mean(l)\n",
        "\n",
        "  grads = g.gradient(l, p_fw)\n",
        "  adam_trainer.apply_gradients(zip(grads, p_fw))\n",
        "\n",
        "  return l\n",
        "\n",
        "predictions, _ = adam_network.forward(p_fw, xrange_inputs)\n",
        "\n",
        "A, ph = ds_factory._create_task()\n",
        "\n",
        "targets = A * np.sin(xrange_inputs + ph)\n",
        "plt.plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "predictions, _= adam_network.forward(p_fw, xrange_inputs)\n",
        "plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "tr_losses = []\n",
        "ev_losses = []\n",
        "\n",
        "for i in range(NUM_STEPS):\n",
        "  adam_step(xt[i], yt[i])\n",
        "\n",
        "  # loss specific to only what we observe.\n",
        "  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "  predictions, _ = adam_network.forward(p_fw, x_observed_so_far)\n",
        "  loss, _ = adam_network.compute_loss(predictions, y_observed_so_far)\n",
        "  tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  # Plotting for the continuous input range\n",
        "  predictions, _ = adam_network.forward(p_fw, xrange_inputs)\n",
        "  if (i+1) % plot_every == 0:\n",
        "    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "  loss, _ = adam_network.compute_loss(predictions, targets)\n",
        "  ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "plt.legend()\n",
        "plt.show()\n",
        "\n",
        "plt.plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')\n",
        "plt.plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')\n",
        "plt.legend()\n",
        "plt.show()\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/brain/python/client:colab_notebook_py3",
        "kind": "private"
      },
      "name": "mplp_sinusoid_5x10.ipynb",
      "provenance": [
        {
          "file_id": "1zaC-kiZBZymEXWR-OGGsYwLZmzv6-I9w",
          "timestamp": 1593621600719
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/self_organising_systems/mplp/notebooks/Sinusoid_L2L_5x10.ipynb?workspaceId=etr:twp::citc",
          "timestamp": 1593591906528
        },
        {
          "file_id": "13MWaLVCK11RJlFnmuz497CVfT2z_Mnze",
          "timestamp": 1593424830833
        },
        {
          "file_id": "1cb0W55ElZyZyCWcditlo9eLJ7VvmK3KX",
          "timestamp": 1587031457071
        },
        {
          "file_id": "1Id1-A7OS9ytTC52Y0T2S7mJjrwuuQ3jO",
          "timestamp": 1584949948813
        },
        {
          "file_id": "1bdTUOQlh7DvhGmH7lsmlYDC0y3k762jb",
          "timestamp": 1582372543431
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
